{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FPjTnCG-v2Z6",
        "outputId": "a7da634c-78c7-4102-cd6e-70e47714a1c2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "zsh:1: command not found: nvidia-smi\n"
          ]
        }
      ],
      "source": [
        "\"\"\"\n",
        "Notebook for training the embedding model for the Lorenz system.\n",
        "=====\n",
        "Distributed by: Notre Dame SCAI Lab (MIT License)\n",
        "- Associated publication:\n",
        "url: https://arxiv.org/abs/2010.03957\n",
        "doi: \n",
        "github: https://github.com/zabaras/transformer-physx\n",
        "=====\n",
        "\"\"\"\n",
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "34XMtg9FZFql"
      },
      "source": [
        "# Environment Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TuqsYQ_2S4kq"
      },
      "source": [
        "Use pip to install from [PyPI](https://pypi.org/project/trphysx/)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9zaXL-m8xEx9",
        "outputId": "a6a686ea-1f21-4a02-c32e-af294db14ebd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[1m\u001b[36mdata\u001b[m\u001b[m/                           train_lorenz_enn.py\n",
            "\u001b[1m\u001b[36moutputs\u001b[m\u001b[m/                        train_lorenz_transformer.ipynb\n",
            "predict_lorenz.py               train_lorenz_transformer.py\n",
            "train_lorenz_enn.ipynb\n"
          ]
        }
      ],
      "source": [
        "%ls"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6MK_h8wF0Rr4"
      },
      "source": [
        "Now let's download the training and validation data for the Lorenz system. Info on wget from [Google drive](https://stackoverflow.com/questions/37453841/download-a-file-from-google-drive-using-wget). This will eventually be updated to a Zenodo repo."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2NtZ02zD0EKo",
        "outputId": "3ca992ae-2885-4e2f-8dd2-b73ee905290f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "mkdir: data: File exists\n"
          ]
        }
      ],
      "source": [
        "# -p keeps this cell re-runnable: no error is raised if ./data already exists\n",
        "!mkdir -p data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dCHiYnrdZN95"
      },
      "source": [
        "# Transformer-PhysX Lorenz System"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YDVeeJHn11Ir"
      },
      "source": [
        "Train the embedding model.\n",
        "First import necessary modules from trphysx. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "dNGVZQ-o1gsH"
      },
      "outputs": [],
      "source": [
        "import sys, os\n",
        "import logging\n",
        "\n",
        "import paddle\n",
        "from paddle.optimizer.lr import ExponentialDecay\n",
        "\n",
        "from trphysx.config.configuration_auto import AutoPhysConfig\n",
        "from trphysx.embedding.embedding_auto import AutoEmbeddingModel\n",
        "from trphysx.embedding.training import *"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YEft0ltg4swx"
      },
      "source": [
        "Training arguments."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "8R8QQ0cj4qR9"
      },
      "outputs": [],
      "source": [
        "argv = []\n",
        "argv = argv + [\"--exp_name\", \"lorenz\"]\n",
        "argv = argv + [\"--training_h5_file\", \"./data/lorenz_training_rk.hdf5\"]\n",
        "argv = argv + [\"--eval_h5_file\", \"./data/lorenz_valid_rk.hdf5\"]\n",
        "argv = argv + [\"--batch_size\", '512']\n",
        "argv = argv + [\"--block_size\", \"16\"]\n",
        "argv = argv + [\"--n_train\", \"2048\"]\n",
        "argv = argv + [\"--n_eval\", \"64\"]\n",
        "argv = argv + [\"--epochs\", \"300\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wcBfHEh540f6",
        "outputId": "a9dc0eeb-013c-4825-9f41-833bc8d242f7"
      },
      "outputs": [],
      "source": [
        "# Parse the CLI-style argument list into the training arguments object\n",
        "args = EmbeddingParser().parse(args=argv)\n",
        "\n",
        "# Setup logging\n",
        "logging.basicConfig(\n",
        "    format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
        "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
        "    level=logging.INFO,\n",
        ")\n",
        "\n",
        "# A CUDA-enabled Paddle build does not guarantee a usable GPU on this machine,\n",
        "# so also require at least one visible device before selecting \"gpu:0\".\n",
        "use_cuda = paddle.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0\n",
        "args.device = paddle.set_device(\"gpu:0\" if use_cuda else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T9Fel9vqZVsH"
      },
      "source": [
        "## Initializing Datasets and Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JjxiqD3lF98_"
      },
      "source": [
        "Now we can use the auto classes to initialize the predefined configs, dataloaders and models. This may take a bit!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EcpC9Fy243RN",
        "outputId": "3b629e8f-b0cb-405f-fd74-ac992c47d92c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "embedding_model.kMatrixDiag paddle.float32 False\n",
            "embedding_model.kMatrixUT paddle.float32 False\n",
            "embedding_model.observableNet.0.weight paddle.float32 False\n",
            "embedding_model.observableNet.0.bias paddle.float32 False\n",
            "embedding_model.observableNet.2.weight paddle.float32 False\n",
            "embedding_model.observableNet.2.bias paddle.float32 False\n",
            "embedding_model.observableNet.3.weight paddle.float32 False\n",
            "embedding_model.observableNet.3.bias paddle.float32 False\n",
            "embedding_model.recoveryNet.0.weight paddle.float32 False\n",
            "embedding_model.recoveryNet.0.bias paddle.float32 False\n",
            "embedding_model.recoveryNet.2.weight paddle.float32 False\n",
            "embedding_model.recoveryNet.2.bias paddle.float32 False\n"
          ]
        }
      ],
      "source": [
        "# Look up the predefined config and data handler registered under the experiment name\n",
        "config = AutoPhysConfig.load_config(args.exp_name)\n",
        "dataloader = AutoDataHandler.load_data_handler(args.exp_name)\n",
        "\n",
        "# Training loader: every size comes from the parsed arguments\n",
        "training_loader = dataloader.createTrainingLoader(\n",
        "    args.training_h5_file,\n",
        "    block_size=args.block_size,\n",
        "    stride=args.stride,\n",
        "    ndata=args.n_train,\n",
        "    batch_size=args.batch_size,\n",
        ")\n",
        "# Evaluation loader: block size (32) and batch size (8) are fixed here, not read from args\n",
        "testing_loader = dataloader.createTestingLoader(\n",
        "    args.eval_h5_file, block_size=32, ndata=args.n_eval, batch_size=8\n",
        ")\n",
        "\n",
        "# Build the embedding trainer model and place it on the device chosen earlier\n",
        "model = AutoEmbeddingModel.init_trainer(args.exp_name, config).to(args.device)\n",
        "# When not starting at the first epoch, load the model state stored under args.ckpt_dir\n",
        "if args.epoch_start > 1:\n",
        "    model.load_model(args.ckpt_dir, args.epoch_start)\n",
        "\n",
        "# Print name, dtype and stop_gradient of each parameter while summing element counts\n",
        "count = 0\n",
        "for name, param in model.named_parameters():\n",
        "    print(name, param.dtype, param.stop_gradient)\n",
        "    count += param.numel()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "auDiMVZ5UfNz"
      },
      "source": [
        "Initialize optimizer and scheduler. Feel free to change if you want to experiment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "bnrtuKdhGuWQ"
      },
      "outputs": [],
      "source": [
        "# One decay factor drives both the schedule and the starting rate: the initial\n",
        "# learning rate is pre-decayed by (epoch_start - 1) steps of the same factor.\n",
        "lr_gamma = 0.995\n",
        "initial_lr = args.lr * lr_gamma ** (args.epoch_start - 1)\n",
        "scheduler = ExponentialDecay(learning_rate=initial_lr, gamma=lr_gamma)\n",
        "\n",
        "# Adam over all model parameters, driven by the scheduler above, with a small\n",
        "# weight decay and global-norm gradient clipping at 0.1.\n",
        "optimizer = paddle.optimizer.Adam(\n",
        "    learning_rate=scheduler,\n",
        "    parameters=model.parameters(),\n",
        "    weight_decay=1e-8,\n",
        "    grad_clip=paddle.nn.ClipGradByGlobalNorm(0.1),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "It_LLGQIZe0b"
      },
      "source": [
        "## Training the Embedding Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2XPKfYTUuXf"
      },
      "source": [
        "Train the model. No visualization here, just boring numbers. This notebook trains for the number of epochs set by `--epochs` in the training arguments above (300 here); feel free to train longer. The test loss here is only the recovery loss MSE(x - decode(encode(x))) and does not reflect the quality of the Koopman dynamics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2ic9DFQcUWpm",
        "outputId": "7532fecf-c988-45b2-a664-b91cf4bef7cb"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/Users/puqing/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/dygraph/math_op_patch.py:275: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float64, but right dtype is paddle.float32, the right dtype will convert to paddle.float64\n",
            "  warnings.warn(\n",
            "/Users/puqing/opt/miniconda3/lib/python3.9/site-packages/astroid/node_classes.py:90: DeprecationWarning: The 'astroid.node_classes' module is deprecated and will be replaced by 'astroid.nodes' in astroid 3.0.0\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "Cell \u001b[0;32mIn [6], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m trainer \u001b[39m=\u001b[39m EmbeddingTrainer(model, args, (optimizer, scheduler))\n\u001b[0;32m----> 2\u001b[0m trainer\u001b[39m.\u001b[39;49mtrain(training_loader, testing_loader)\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/trphysx-0.0.8-py3.9.egg/trphysx/embedding/training/enn_trainer.py:109\u001b[0m, in \u001b[0;36mEmbeddingTrainer.train\u001b[0;34m(self, training_loader, eval_dataloader)\u001b[0m\n\u001b[1;32m    106\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mclear_gradients()\n\u001b[1;32m    107\u001b[0m \u001b[39mfor\u001b[39;00m mbidx, inputs \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(training_loader):\n\u001b[0;32m--> 109\u001b[0m     loss0, loss_reconstruct0 \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49minputs)\n\u001b[1;32m    110\u001b[0m     loss0 \u001b[39m=\u001b[39m loss0\u001b[39m.\u001b[39msum()\n\u001b[1;32m    112\u001b[0m     loss_reconstruct \u001b[39m=\u001b[39m loss_reconstruct \u001b[39m+\u001b[39m loss_reconstruct0\u001b[39m.\u001b[39msum()\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/dygraph/layers.py:948\u001b[0m, in \u001b[0;36mLayer.__call__\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m    945\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mnot\u001b[39;00m in_declarative_mode()) \u001b[39mand\u001b[39;00m (\u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks) \\\n\u001b[1;32m    946\u001b[0m     \u001b[39mand\u001b[39;00m (\u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_post_hooks) \u001b[39mand\u001b[39;00m (\u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_built) \u001b[39mand\u001b[39;00m in_dygraph_mode() \u001b[39mand\u001b[39;00m (\u001b[39mnot\u001b[39;00m in_profiler_mode()):\n\u001b[1;32m    947\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_build_once(\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m--> 948\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward(\u001b[39m*\u001b[39;49minputs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    949\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    950\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dygraph_call_func(\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/trphysx-0.0.8-py3.9.egg/trphysx/embedding/embedding_lorenz.py:227\u001b[0m, in \u001b[0;36mLorenzEmbeddingTrainer.forward\u001b[0;34m(self, states)\u001b[0m\n\u001b[1;32m    224\u001b[0m xin0 \u001b[39m=\u001b[39m states[:, t0, :]\n\u001b[1;32m    225\u001b[0m _, xRec1 \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membedding_model(xin0)\n\u001b[0;32m--> 227\u001b[0m g1Pred \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49membedding_model\u001b[39m.\u001b[39;49mkoopmanOperation(g1_old)\n\u001b[1;32m    228\u001b[0m xgRec1 \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membedding_model\u001b[39m.\u001b[39mrecover(g1Pred)\n\u001b[1;32m    229\u001b[0m xin0 \u001b[39m=\u001b[39m paddle\u001b[39m.\u001b[39mto_tensor(xin0, dtype\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mfloat32\u001b[39m\u001b[39m\"\u001b[39m)\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/trphysx-0.0.8-py3.9.egg/trphysx/embedding/embedding_lorenz.py:142\u001b[0m, in \u001b[0;36mLorenzEmbedding.koopmanOperation\u001b[0;34m(self, g)\u001b[0m\n\u001b[1;32m    140\u001b[0m kMatrix \u001b[39m=\u001b[39m paddle\u001b[39m.\u001b[39mto_tensor(paddle\u001b[39m.\u001b[39mzeros((\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobsdim, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobsdim)))\n\u001b[1;32m    141\u001b[0m \u001b[39m# Populate the off diagonal terms\u001b[39;00m\n\u001b[0;32m--> 142\u001b[0m kMatrix[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mxidx, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49myidx] \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkMatrixUT\n\u001b[1;32m    143\u001b[0m kMatrix[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39myidx, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mxidx] \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkMatrixUT\n\u001b[1;32m    145\u001b[0m \u001b[39m# Populate the diagonal\u001b[39;00m\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/dygraph/varbase_patch_methods.py:793\u001b[0m, in \u001b[0;36mmonkey_patch_varbase.<locals>.__setitem__\u001b[0;34m(self, item, value)\u001b[0m\n\u001b[1;32m    788\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m    790\u001b[0m \u001b[39mif\u001b[39;00m contain_tensor_or_list(item) \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m is_combine_index(item):\n\u001b[1;32m    791\u001b[0m     \u001b[39m# To reuse code with static graph,\u001b[39;00m\n\u001b[1;32m    792\u001b[0m     \u001b[39m# Call _setitem_impl_ when item contains tensor or list.\u001b[39;00m\n\u001b[0;32m--> 793\u001b[0m     \u001b[39mreturn\u001b[39;00m _setitem_impl_(\u001b[39mself\u001b[39;49m, item, value)\n\u001b[1;32m    795\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    796\u001b[0m     \u001b[39mif\u001b[39;00m framework\u001b[39m.\u001b[39m_in_eager_mode_:\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/variable_index.py:685\u001b[0m, in \u001b[0;36m_setitem_impl_\u001b[0;34m(var, item, value)\u001b[0m\n\u001b[1;32m    681\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(slice_info\u001b[39m.\u001b[39mindexes) \u001b[39m!=\u001b[39m \u001b[39mlen\u001b[39m(item):\n\u001b[1;32m    682\u001b[0m         \u001b[39mraise\u001b[39;00m \u001b[39mIndexError\u001b[39;00m(\n\u001b[1;32m    683\u001b[0m             \u001b[39m\"\u001b[39m\u001b[39mValid index accept int or slice or ellipsis or list, but received \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    684\u001b[0m             \u001b[39m.\u001b[39mformat(item))\n\u001b[0;32m--> 685\u001b[0m     \u001b[39mreturn\u001b[39;00m slice_info\u001b[39m.\u001b[39;49mset_item(var, value)\n\u001b[1;32m    686\u001b[0m attrs \u001b[39m=\u001b[39m {\n\u001b[1;32m    687\u001b[0m     \u001b[39m'\u001b[39m\u001b[39maxes\u001b[39m\u001b[39m'\u001b[39m: axes,\n\u001b[1;32m    688\u001b[0m     \u001b[39m'\u001b[39m\u001b[39mstarts\u001b[39m\u001b[39m'\u001b[39m: starts,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    692\u001b[0m     \u001b[39m'\u001b[39m\u001b[39mnone_axes\u001b[39m\u001b[39m'\u001b[39m: none_axes\n\u001b[1;32m    693\u001b[0m }\n\u001b[1;32m    695\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mlayers\u001b[39;00m \u001b[39mimport\u001b[39;00m utils\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/variable_index.py:197\u001b[0m, in \u001b[0;36mSliceInfo.set_item\u001b[0;34m(self, tensor_origin, value)\u001b[0m\n\u001b[1;32m    195\u001b[0m \u001b[39mif\u001b[39;00m tensor_type \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    196\u001b[0m     out \u001b[39m=\u001b[39m out\u001b[39m.\u001b[39mastype(tensor_type)\n\u001b[0;32m--> 197\u001b[0m tensor_origin[:] \u001b[39m=\u001b[39m out\u001b[39m.\u001b[39mreshape(tensor_origin\u001b[39m.\u001b[39mshape)\n\u001b[1;32m    199\u001b[0m \u001b[39mreturn\u001b[39;00m tensor_origin\n",
            "File \u001b[0;32m~/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/dygraph/varbase_patch_methods.py:797\u001b[0m, in \u001b[0;36mmonkey_patch_varbase.<locals>.__setitem__\u001b[0;34m(self, item, value)\u001b[0m\n\u001b[1;32m    795\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    796\u001b[0m     \u001b[39mif\u001b[39;00m framework\u001b[39m.\u001b[39m_in_eager_mode_:\n\u001b[0;32m--> 797\u001b[0m         \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__setitem_eager_tensor__(item, value)\n\u001b[1;32m    798\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    799\u001b[0m         \u001b[39m# Call c++ func __setitem_varbase__ to speedup.\u001b[39;00m\n\u001b[1;32m    800\u001b[0m         \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__setitem_varbase__(item, value)\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ],
      "source": [
        "# The trainer receives the optimizer and its LR scheduler together as one tuple\n",
        "optimizers = (optimizer, scheduler)\n",
        "trainer = EmbeddingTrainer(model, args, optimizers)\n",
        "\n",
        "# Run training on the training loader, evaluating with the testing loader\n",
        "trainer.train(training_loader, testing_loader)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h3ZEE4ixh8nr"
      },
      "source": [
        "Check your Google drive for checkpoints."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "train_lorenz_enn.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "base",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.12"
    },
    "vscode": {
      "interpreter": {
        "hash": "db64661fe0d122184d8f0ece1104b953994d7acbc475dabbdd4eb4bc24907e06"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
